from asyncio.constants import SENDFILE_FALLBACK_READBUFFER_SIZE
from decimal import DivisionImpossible
from sqlite3 import Timestamp, TimestampFromTicks
from ryu.base import app_manager
from ryu.controller import ofp_event
from ryu.controller.handler import CONFIG_DISPATCHER, MAIN_DISPATCHER
from ryu.controller.handler import set_ev_cls
from ryu.ofproto import ofproto_v1_3
from ryu.lib.packet import packet
from ryu.lib.packet import ethernet, tcp, ipv4
from ryu.lib.packet import ether_types

import numpy as np
import time

class Eftest(app_manager.RyuApp):
    """Learning switch that also measures per-link TCP latency and loss.

    Every TCP/IPv4 packet forwarded to a known (non-flood) port is recorded in
    ``flow_monitor`` with the switches it visited, the ports it left through
    and when it was seen. When its last output port leads to no monitored
    switch (``topology_monitor`` value 0), its per-hop delays are moved into
    ``latency_monitor``. Packets not seen again within ``timeout`` seconds are
    counted as lost on the link they were last sent over. Both checks run only
    when a new TCP packet-in arrives.
    """
    OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]

    # Number of most recent samples kept per (src dpid, dst dpid) link.
    LATENCY_WINDOW = 100
    # Loss count at/above which the severe warning is printed and probing flagged.
    SEVERE_LOSS_COUNT = 50
    # Sample value stored in latency_monitor to mark a lost packet.
    LOSS_MARKER = -1

    def __init__(self, *args, **kwargs):
        super(Eftest, self).__init__(*args, **kwargs)
        # mac_to_port[zero-padded dpid string][mac] -> ingress port learned for mac.
        self.mac_to_port = {}

        # One record per TCP segment in flight. 'trace', 'portOut' and
        # 'timestamp' are parallel Python lists: visited dpids, the output
        # port used at each of them, and time.time() when each hop was seen.
        self.flow_monitor_fieldinfo = [('srcIP', 'U16'), ('dstIP', 'U16'),
                ('srcPort', 'uint'), ('dstPort', 'uint'), ('seq', 'uint'),
                ('ack', 'uint'), ('tcpCsum', 'uint'), ('trace', 'object'),
                ('portOut', 'object'), ('timestamp', 'object'), ('rawData', 'object')]
        self.flow_monitor = np.array([], dtype=self.flow_monitor_fieldinfo)

        # Set when severe loss is detected; nothing in this class reads it yet.
        self.probePkt = False

        # topology_monitor[dpid][port] -> dpid of the switch behind that port;
        # 0 means no monitored switch (the packet leaves the monitored path,
        # presumably toward a host — confirm against the lab topology).
        self.topology_monitor = {
            1: {
                1: 2,
                2: 3,
                3: 4
            },
            2: {
                1: 1,
                2: 0,
                3: 0
            },
            3: {
                1: 1,
                2: 0,
                3: 0,
                4: 0
            },
            4: {
                1: 1,
                2: 0,
                3: 0,
                4: 0
            }
        }

        # latency_monitor[src dpid][dst dpid] -> recent hop delays in seconds,
        # with LOSS_MARKER entries for lost packets.
        self.latency_monitor = {}

        # Seconds without a new hop after which a packet is counted as lost.
        self.timeout = 3

    # Install default flow option
    @set_ev_cls(ofp_event.EventOFPSwitchFeatures, CONFIG_DISPATCHER)
    def switch_features_handler(self, ev):
        """Install a priority-0 table-miss flow sending all packets to the controller."""
        datapath = ev.msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser

        match = parser.OFPMatch()
        # OFPCML_NO_BUFFER: send the whole packet rather than a buffered prefix.
        actions = [parser.OFPActionOutput(ofproto.OFPP_CONTROLLER,
                                          ofproto.OFPCML_NO_BUFFER)]
        self.add_flow(datapath, 0, match, actions)

    # Customized function to add flow.
    def add_flow(self, datapath, priority, match, actions, buffer_id=None):
        """Send an OFPFlowMod applying `actions` to packets matching `match`.

        :param buffer_id: switch buffer to release through the new flow, if any.
        """
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser

        inst = [parser.OFPInstructionActions(ofproto.OFPIT_APPLY_ACTIONS,
                                             actions)]
        if buffer_id:
            mod = parser.OFPFlowMod(datapath=datapath, buffer_id=buffer_id,
                                    priority=priority, match=match,
                                    instructions=inst)
        else:
            mod = parser.OFPFlowMod(datapath=datapath, priority=priority,
                                    match=match, instructions=inst)
        datapath.send_msg(mod)

    @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
    def _packet_in_handler(self, ev):
        """Forward the packet as a learning switch, then update the TCP monitors.

        No flow is installed for learned destinations, so every packet keeps
        reaching the controller and can be tracked hop by hop.
        """
        msg = ev.msg
        # Truncated packet error log.
        if msg.msg_len < msg.total_len:
            self.logger.debug("packet truncated: only %s of %s bytes",
                              msg.msg_len, msg.total_len)
        datapath = msg.datapath
        ofproto = datapath.ofproto
        parser = datapath.ofproto_parser
        in_port = msg.match['in_port']

        pkt = packet.Packet(msg.data)
        eth = pkt.get_protocols(ethernet.ethernet)[0]
        if eth.ethertype == ether_types.ETH_TYPE_LLDP:
            # Ignore lldp packet
            return

        dpid = format(datapath.id, "d").zfill(16)
        self.mac_to_port.setdefault(dpid, {})
        self.latency_monitor.setdefault(datapath.id, {})

        # Learn a mac address to avoid FLOOD next time.
        self.mac_to_port[dpid][eth.src] = in_port
        out_port = self.mac_to_port[dpid].get(eth.dst, ofproto.OFPP_FLOOD)

        # Send data out; attach the payload only if the switch did not buffer it.
        actions = [parser.OFPActionOutput(out_port)]
        data = msg.data if msg.buffer_id == ofproto.OFP_NO_BUFFER else None
        out = parser.OFPPacketOut(datapath=datapath, buffer_id=msg.buffer_id,
                                  in_port=in_port, actions=actions, data=data)
        datapath.send_msg(out)

        # A flooded packet has no single next hop to attribute latency to.
        if out_port == ofproto.OFPP_FLOOD:
            return

        # Only TCP carried over IPv4 is monitored (the registry stores IPv4 addresses).
        tcpHead = pkt.get_protocol(tcp.tcp)
        ipHead = pkt.get_protocol(ipv4.ipv4)
        if tcpHead is None or ipHead is None:
            return

        now = time.time()
        self._register_packet(ipHead, tcpHead, pkt.data, datapath.id, out_port, now)
        self._collect_finished()
        self._detect_losses(now)

    def _next_hop(self, dpid, port):
        """Return the dpid behind `port` of `dpid`; 0 if none or not in topology_monitor."""
        # Unknown switches/ports are treated like edge ports instead of raising
        # KeyError mid-update and leaving the registry half-modified.
        return self.topology_monitor.get(dpid, {}).get(port, 0)

    def _record_latency(self, src, dst, sample):
        """Append `sample` to link src->dst, keeping only the last LATENCY_WINDOW samples."""
        samples = self.latency_monitor.setdefault(src, {}).setdefault(dst, [])
        samples.append(sample)
        del samples[:-self.LATENCY_WINDOW]

    def _register_packet(self, ipHead, tcpHead, rawData, currNode, portOut, timestamp):
        """Record one hop of a TCP segment, creating its registry entry if it is new.

        A segment is identified by IP addresses, TCP ports and TCP checksum.
        """
        fm = self.flow_monitor
        filt = ((fm['srcIP'] == ipHead.src) & (fm['dstIP'] == ipHead.dst)
                & (fm['srcPort'] == tcpHead.src_port)
                & (fm['dstPort'] == tcpHead.dst_port)
                & (fm['tcpCsum'] == tcpHead.csum))
        matches = fm[filt]
        if matches.size == 0:
            newEntry = np.array([(ipHead.src, ipHead.dst, tcpHead.src_port,
                                  tcpHead.dst_port, tcpHead.seq, tcpHead.ack,
                                  tcpHead.csum, [currNode], [portOut], [timestamp],
                                  rawData)], dtype=self.flow_monitor_fieldinfo)
            self.flow_monitor = np.concatenate((fm, newEntry))
        else:
            # Boolean indexing copies the record, but its object fields still
            # reference the stored lists, so appending updates the registry.
            entry = matches[0]
            entry['trace'].append(currNode)
            entry['timestamp'].append(timestamp)
            entry['portOut'].append(portOut)

    def _collect_finished(self):
        """Move hop delays of packets that left the monitored path into latency_monitor."""
        finished = np.array(
            [self._next_hop(entry['trace'][-1], entry['portOut'][-1]) == 0
             for entry in self.flow_monitor], dtype=bool)

        for entry in self.flow_monitor[finished]:
            trace = entry['trace']
            timeline = entry['timestamp']
            hops = list(zip(trace[:-1], timeline[:-1], trace[1:], timeline[1:]))
            # Walk the path from its last hop backwards.
            for startNode, startTime, endNode, endTime in reversed(hops):
                self._record_latency(startNode, endNode, endTime - startTime)

        # Clear finished packet registry
        self.flow_monitor = self.flow_monitor[~finished]

    def _detect_losses(self, now):
        """Count packets unseen for more than `timeout` seconds as lost and drop them."""
        keep = []
        for entry in self.flow_monitor:
            if now - entry['timestamp'][-1] <= self.timeout:
                keep.append(True)
                continue
            keep.append(False)
            dpidSrc = entry['trace'][-1]
            portTo = entry['portOut'][-1]
            dpidDst = self._next_hop(dpidSrc, portTo)
            if dpidDst == 0:
                # No monitored link to attribute the loss to.
                continue
            self._record_latency(dpidSrc, dpidDst, self.LOSS_MARKER)
            self._report_loss(dpidSrc, dpidDst, portTo)

        # Clear lost packet registry
        self.flow_monitor = self.flow_monitor[np.array(keep, dtype=bool)]

    def _report_loss(self, dpidSrc, dpidDst, portTo):
        """Print loss rate and mean delay for link dpidSrc->dpidDst; flag probing if severe."""
        lagData = np.array(self.latency_monitor[dpidSrc][dpidDst])
        pktCnt = lagData.size
        lossCnt = int(np.count_nonzero(lagData == self.LOSS_MARKER))
        delivered = lagData[lagData != self.LOSS_MARKER]
        # With no delivered sample the mean is undefined; report 0 instead of
        # dividing by zero (a NaN would make the '%d' format raise ValueError).
        avgLag = np.sum(delivered) / delivered.size if delivered.size else 0.0
        lossRT = lossCnt / pktCnt
        detail = ('dpid:(%d) to dpid:(%d) OutPort:(%d) LossRT:%d AvgLag:%d ms'
                  % (dpidSrc, dpidDst, portTo, lossRT * 100, avgLag * 1000))
        if lossCnt < self.SEVERE_LOSS_COUNT:
            print('Packet loss detected: ' + detail)
        else:
            print('WARNING:Severe packet loss detected: ' + detail)
            print('Sending out Probe packets...')
            self.probePkt = True
        